import torch
import torch.nn as nn
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from models.archs.encoders.pointnet2_utils.pointnet2.pointnet2_modules import PointnetFPModule, \
    PointnetSAModuleMSG, PointnetSAModule, PointnetFPModule_new
from models.archs.encoders.pointnet2_utils.pointnet2 import pointnet2_utils
import models.archs.encoders.pointnet2_utils.pointnet2.pytorch_utils as pt_utils
from ipdb import set_trace
import torch.nn.init as init
from torch.nn.modules import dropout
from torch.nn.modules.batchnorm import BatchNorm1d
import torch
import torch.nn.functional as F

def get_model(input_channels=0):
    """Build and return a ``Pointnet2MSG`` network.

    Args:
        input_channels: feature channels handed to ``Pointnet2MSG`` as its
            ``input_channels``. Note that this default (0) differs from
            ``Pointnet2MSG``'s own default (6).
    """
    model = Pointnet2MSG(input_channels=input_channels)
    return model


# Hyper-parameters for ``Pointnet2MSG``.
# NPOINTS / RADIUS / NSAMPLE / MLPS hold one entry per set-abstraction (SA)
# level. Inside a level, RADIUS / NSAMPLE / MLPS hold one entry per grouping
# scale (multi-scale grouping).
MSG_CFG = {
    # ``npoint`` passed to each SA level.
    'NPOINTS': [512, 256, 128, 64],
    # Per-scale ball-query radii (``radii``); presumably in the same units as
    # the xyz coordinates -- TODO confirm.
    'RADIUS': [[0.01, 0.02], [0.02, 0.04], [0.04, 0.08], [0.08, 0.16]],
    # Per-scale neighbour counts (``nsamples``).
    'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32]],
    # Per-scale MLP widths. Pointnet2MSG prepends the incoming channel count
    # and sums the last width of every scale to get the level's output width.
    'MLPS': [[[16, 16, 32], [32, 32, 64]],
             [[64, 64, 128], [64, 96, 128]],
             [[128, 196, 256], [128, 196, 256]],
             [[256, 256, 512], [256, 384, 512]]],
    # Feature-propagation MLP widths. Entry k produces the features of
    # level k (entry 0 = the input points).
    'FP_MLPS': [[64, 64], [128, 128], [256, 256], [512, 512]],
    # Hidden widths of Pointnet2MSG.cls_layer before its final 1-channel conv.
    'CLS_FC': [128],
    # Presumably the dropout ratio; Pointnet2MSG's head uses dropout 0.5.
    'DP_RATIO': 0.5,
}

# Candidate backbone for ``Pointnet2ClsMSG``; the one actually used is chosen
# through SELECTED_PARAMS. The layout matches MSG_CFG.
# A None in NPOINTS / RADIUS / NSAMPLE marks the final level, presumably a
# group-all level -- TODO confirm against PointnetSAModuleMSG.
# Variant: five SA levels, two grouping scales each.
ClsMSG_CFG = {
    'NPOINTS': [512, 256, 128, 64, None],
    'RADIUS': [[0.01, 0.02], [0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
    'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32], [None, None]],
    'MLPS': [[[16, 16, 32], [32, 32, 64]],
             [[64, 64, 128], [64, 96, 128]],
             [[128, 196, 256], [128, 196, 256]],
             [[256, 256, 512], [256, 384, 512]],
             [[512, 512], [512, 512]]],
    'DP_RATIO': 0.5,  # presumably a dropout ratio; not read by code in this file
}

# Candidate backbone for ``Pointnet2ClsMSG`` (selected through SELECTED_PARAMS).
# Variant: four SA levels, two scales each. Compared with ClsMSG_CFG it starts
# at larger radii and draws more neighbours at the first level (32 / 64).
# The final level (None entries) is presumably group-all -- TODO confirm.
ClsMSG_CFG_Dense = {
    'NPOINTS': [512, 256, 128, None],
    'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
    'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]],
    'MLPS': [[[16, 16, 32], [32, 32, 64]],
             [[64, 64, 128], [64, 96, 128]],
             [[128, 196, 256], [128, 196, 256]],
             [[256, 256, 512], [256, 384, 512]]],
    'DP_RATIO': 0.5,  # presumably a dropout ratio; not read by code in this file
}


# Candidate backbone for ``Pointnet2ClsMSG`` (currently the one SELECTED_PARAMS
# points to).
# Variant: four SA levels, two scales each, with the same radii as
# ClsMSG_CFG_Dense but 16 / 32 neighbours at every grouped level.
# The final level (None entries) is presumably group-all -- TODO confirm.
ClsMSG_CFG_Light = {
    'NPOINTS': [512, 256, 128, None],
    'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
    'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
    'MLPS': [[[16, 16, 32], [32, 32, 64]],
             [[64, 64, 128], [64, 96, 128]],
             [[128, 196, 256], [128, 196, 256]],
             [[256, 256, 512], [256, 384, 512]]],
    'DP_RATIO': 0.5,  # presumably a dropout ratio; not read by code in this file
}

# Candidate backbone for ``Pointnet2ClsMSG`` (selected through SELECTED_PARAMS).
# Variant: five SA levels with a single grouping scale per level.
# The final level (None entries) is presumably group-all -- TODO confirm.
ClsMSG_CFG_Lighter = {
    'NPOINTS': [512, 256, 128, 64, None],
    'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]],
    'NSAMPLE': [[64], [32], [16], [8], [None]],
    'MLPS': [[[32, 32, 64]],
             [[64, 64, 128]],
             [[128, 196, 256]],
             [[256, 256, 512]],
             [[512, 512, 1024]]],
    'DP_RATIO': 0.5,  # presumably a dropout ratio; not read by code in this file
}

# Configuration read by ``Pointnet2ClsMSG.__init__``. Point this at another
# ClsMSG_CFG* dict to change the classification backbone. The network's output
# width is the sum of the last widths in the final level's MLPS
# (512 + 512 = 1024 for ClsMSG_CFG_Light).
SELECTED_PARAMS = ClsMSG_CFG_Light


class Pointnet2MSG(nn.Module):
    """PointNet++ multi-scale-grouping (MSG) encoder/decoder driven by ``MSG_CFG``.

    A stack of set-abstraction (SA) levels downsamples the cloud. A stack of
    feature-propagation (FP) levels then carries features from each coarser
    level back to the next finer one, ending at the input points (level 0).
    ``forward`` returns the level-0 features.

    ``self.cls_layer`` is a conv head ending in one output channel. It is
    built here but is not applied by ``forward``.

    Args:
        input_channels: feature width fed to the first SA level. It is
            expected to match the number of channels after xyz in the input
            point cloud (``pointcloud[..., 3:]``).
    """

    def __init__(self, input_channels=6):
        super().__init__()

        # ---- Set-abstraction (downsampling) levels ----
        self.SA_modules = nn.ModuleList()
        channel_in = input_channels
        # Feature width at every level, starting with the raw input features.
        # The FP levels use these as skip-connection widths.
        skip_channel_list = [input_channels]
        for k in range(len(MSG_CFG['NPOINTS'])):
            # Prepend the incoming width to each scale's MLP. A fresh list is
            # built, so MSG_CFG itself is never mutated.
            mlps = [[channel_in] + scale_mlp for scale_mlp in MSG_CFG['MLPS'][k]]
            # The level's output width is the sum of every scale's last width.
            channel_out = sum(scale_mlp[-1] for scale_mlp in mlps)

            self.SA_modules.append(
                PointnetSAModuleMSG(
                    npoint=MSG_CFG['NPOINTS'][k],
                    radii=MSG_CFG['RADIUS'][k],
                    nsamples=MSG_CFG['NSAMPLE'][k],
                    mlps=mlps,
                    use_xyz=True,
                    bn=True
                )
            )
            skip_channel_list.append(channel_out)
            channel_in = channel_out

        # ---- Feature-propagation (upsampling) levels ----
        # FP_modules[k] produces the features of level k. Its MLP input width
        # is the skip width of level k plus the width of the features coming
        # up from level k + 1. Those come from FP_modules[k + 1], or from the
        # deepest SA level for the last FP module.
        self.FP_modules = nn.ModuleList()
        fp_mlps = MSG_CFG['FP_MLPS']
        for k in range(len(fp_mlps)):
            pre_channel = fp_mlps[k + 1][-1] if k + 1 < len(fp_mlps) else channel_out
            self.FP_modules.append(
                PointnetFPModule(mlp=[pre_channel + skip_channel_list[k]] + fp_mlps[k])
            )

        # ---- Point-wise conv head (not used by forward) ----
        cls_layers = []
        pre_channel = fp_mlps[0][-1]
        for width in MSG_CFG['CLS_FC']:
            cls_layers.append(pt_utils.Conv1d(pre_channel, width, bn=True))
            pre_channel = width
        cls_layers.append(pt_utils.Conv1d(pre_channel, 1, activation=None))
        # Dropout goes right after the first conv. The rate comes from the
        # config rather than a duplicated literal (MSG_CFG['DP_RATIO'] is 0.5).
        cls_layers.insert(1, nn.Dropout(MSG_CFG['DP_RATIO']))
        self.cls_layer = nn.Sequential(*cls_layers)

    def _break_up_pc(self, pc):
        """Split a point cloud into coordinates and optional per-point features.

        Args:
            pc: tensor whose last dim holds xyz followed by any extra channels;
                presumably shaped (B, N, 3 + C).

        Returns:
            ``(xyz, features)``. ``xyz`` is a contiguous ``pc[..., 0:3]``.
            ``features`` is the contiguous channels-first
            ``pc[..., 3:].transpose(1, 2)``, or ``None`` when there are no
            extra channels.
        """
        xyz = pc[..., 0:3].contiguous()
        features = (
            pc[..., 3:].transpose(1, 2).contiguous()
            if pc.size(-1) > 3 else None
        )

        return xyz, features

    def forward(self, pointcloud: torch.Tensor):
        """Encode ``pointcloud`` and propagate features back to its points.

        Returns:
            The propagated features at level 0 (the input points).
        """
        xyz, features = self._break_up_pc(pointcloud)

        # Each SA level consumes the output of the previous one.
        l_xyz, l_features = [xyz], [features]
        for sa_module in self.SA_modules:
            li_xyz, li_features = sa_module(l_xyz[-1], l_features[-1])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        # Deepest level first: FP_modules[i] (negative i) writes the features
        # of level i - 1 from those of level i. This assumes there are as many
        # FP modules as SA modules, so the last write lands on level 0.
        for i in range(-1, -(len(self.FP_modules) + 1), -1):
            l_features[i - 1] = self.FP_modules[i](
                l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
            )

        return l_features[0]


class Pointnet2ClsMSG(nn.Module):
    """PointNet++ multi-scale-grouping encoder configured by ``SELECTED_PARAMS``.

    Only set-abstraction (SA) levels are built; there is no feature
    propagation. ``forward`` returns the features of the last (coarsest) level
    with any trailing singleton dimension squeezed away.

    Args:
        input_channels: feature width fed to the first SA level.
    """

    def __init__(self, input_channels=6):
        super().__init__()

        cfg = SELECTED_PARAMS
        self.SA_modules = nn.ModuleList()
        width = input_channels
        for level, npoint in enumerate(cfg['NPOINTS']):
            # Every scale's MLP starts from the width produced by the level below.
            scale_mlps = [[width] + spec for spec in cfg['MLPS'][level]]
            self.SA_modules.append(
                PointnetSAModuleMSG(
                    npoint=npoint,
                    radii=cfg['RADIUS'][level],
                    nsamples=cfg['NSAMPLE'][level],
                    mlps=scale_mlps,
                    use_xyz=True,
                    bn=True,
                )
            )
            # The next level sees the summed output widths of all scales.
            width = sum(spec[-1] for spec in scale_mlps)

    def _break_up_pc(self, pc):
        """Return ``(xyz, features)`` from a point-cloud tensor.

        ``xyz`` is a contiguous ``pc[..., 0:3]``. ``features`` is the
        contiguous ``pc[..., 3:].transpose(1, 2)``, or ``None`` when the last
        dim has no channels beyond xyz.
        """
        xyz = pc[..., 0:3].contiguous()
        if pc.size(-1) <= 3:
            return xyz, None
        return xyz, pc[..., 3:].transpose(1, 2).contiguous()

    def forward(self, pointcloud: torch.cuda.FloatTensor):
        """Run every SA level in order and return the squeezed final features."""
        xyz, features = self._break_up_pc(pointcloud)
        for sa_module in self.SA_modules:
            xyz, features = sa_module(xyz, features)
        return features.squeeze(-1)


class FixedPointNetFPModule(PointnetFPModule_new):
    """Feature-propagation layer: move coarse-level features onto finer points.

    When coarse coordinates are given, each fine point receives an
    inverse-distance-weighted blend of its three nearest coarse points'
    features. Those features are then concatenated with the fine level's own
    features, if any, and passed through ``self.mlp``.
    """

    def __init__(self, mlp, bn=True):
        super().__init__(mlp, bn)

    def forward(self, unknown, known, unknow_feats, known_feats):
        """Propagate ``known_feats`` (at ``known``) onto the ``unknown`` points.

        Args:
            unknown: coordinates of the points that receive features;
                ``unknown.size(1)`` is taken as their count.
            known: coordinates of the coarse points, or ``None``. With
                ``None``, ``known_feats`` is broadcast with ``expand``, so its
                last dim is presumably 1 (e.g. a group-all level).
            unknow_feats: optional features already at the ``unknown`` points,
                concatenated on dim 1.
            known_feats: features at the ``known`` points, channels on dim 1.

        Returns:
            ``self.mlp`` applied to the combined features, with the helper
            trailing dimension squeezed back out.
        """
        if known is None:
            target_shape = list(known_feats.size()[0:2]) + [unknown.size(1)]
            interpolated_feats = known_feats.expand(*target_shape)
        else:
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            # Inverse-distance weights; the epsilon avoids division by zero
            # when a fine point coincides with a coarse one.
            inv_dist = 1.0 / (dist + 1e-8)
            weight = inv_dist / torch.sum(inv_dist, dim=2, keepdim=True)
            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight
            )

        if unknow_feats is None:
            stacked = interpolated_feats
        else:
            stacked = torch.cat([interpolated_feats, unknow_feats], dim=1)

        # self.mlp operates on a 4-D tensor, so add and later drop a trailing axis.
        return self.mlp(stacked.unsqueeze(-1)).squeeze(-1)


class SegNet(nn.Module):
    """Three-level PointNet++ (single-scale grouping) per-point classifier.

    The network has four stages:

    * Three set-abstraction levels: 512 points / r=0.2, 128 points / r=0.4,
      then a level built with ``npoint=None``.
    * Three feature-propagation levels that carry features back to the input
      points.
    * A shared 128-channel conv head.
    * A per-point softmax over ``num_classes``.

    Args:
        num_classes: number of output classes per point.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Set abstraction. ``mlp[0]`` is the incoming feature width. The first
        # level gets no per-point features (0), so it works from xyz alone.
        self.sa_module_1 = PointnetSAModule(
            npoint=512,
            radius=0.2,
            nsample=64,
            mlp=[0, 64, 64, 128],
            bn=True,
            use_xyz=True,
        )

        self.sa_module_2 = PointnetSAModule(
            npoint=128,
            radius=0.4,
            nsample=64,
            mlp=[128, 128, 128, 256],
            bn=True,
            use_xyz=True,
        )

        self.sa_module_3 = PointnetSAModule(
            npoint=None,
            radius=None,
            nsample=None,
            mlp=[256, 256, 512, 1024],
            bn=True,
            use_xyz=True,
        )

        # Feature propagation, coarsest first. Each input width is the skip
        # features of the finer level plus the propagated features:
        #   fp1: 256 (level 2) + 1024 (level 3)
        #   fp2: 128 (level 1) + 256 (fp1 output)
        #   fp3: 128 (fp2 output) only, since level 0 carries no features.
        self.fp_module_1 = FixedPointNetFPModule(mlp=[256 + 1024, 256, 256])
        self.fp_module_2 = FixedPointNetFPModule(mlp=[128 + 256, 256, 128])
        self.fp_module_3 = FixedPointNetFPModule(mlp=[128, 128, 128, 128])

        self.fc_layer = nn.Sequential(
            nn.Conv1d(128, 128, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.Dropout(0.5)
        )

        # 1x1 conv mapping the 128 head channels to per-class scores.
        self.grasp_seg_layer = nn.Conv1d(128, num_classes, kernel_size=1, padding=0)

    def forward(self, input):
        """Predict per-point class probabilities.

        Args:
            input: point-cloud tensor. ``input[:, :, :3]`` is used as xyz and
                any further channels are ignored. Presumably shaped
                (B, N, C >= 3) -- TODO confirm.

        Returns:
            Class probabilities with a softmax over the last dim. The shape is
            (B, N, num_classes), assuming the FP stage restores all N input
            points.
        """
        # Slicing off the first 3 channels yields a non-contiguous view
        # whenever C > 3. The PointNet++ sampling/grouping kernels typically
        # require contiguous xyz (the ``_break_up_pc`` helpers in this module
        # also call ``.contiguous()``). This is a no-op when already contiguous.
        l0_xyz = input[:, :, :3].contiguous()
        # The first SA level is built with 0 input feature channels.
        l0_features = None

        l1_xyz, l1_features = self.sa_module_1(l0_xyz, l0_features)
        l2_xyz, l2_features = self.sa_module_2(l1_xyz, l1_features)
        l3_xyz, l3_features = self.sa_module_3(l2_xyz, l2_features)

        l2_features = self.fp_module_1(l2_xyz, l3_xyz, l2_features, l3_features)
        l1_features = self.fp_module_2(l1_xyz, l2_xyz, l1_features, l2_features)
        l0_features = self.fp_module_3(l0_xyz, l1_xyz, l0_features, l1_features)
        total_feature = self.fc_layer(l0_features)
        # Conv1d output is channels-first; move classes to the last dim.
        pred_seg = self.grasp_seg_layer(total_feature).transpose(1, 2)

        pred_seg = F.softmax(pred_seg, dim=2)

        return pred_seg


if __name__ == '__main__':
    # CUDA smoke test for the classification backbone:
    #   1. Seed the CPU and CUDA RNGs.
    #   2. Build Pointnet2ClsMSG with no extra feature channels.
    #   3. Print the per-cloud mean of a random batch.
    #   4. Print the network's output shape.
    rng_seed = 100
    torch.manual_seed(rng_seed)
    torch.cuda.manual_seed(rng_seed)

    # Built before sampling the input so the RNG draw order is unchanged.
    model = Pointnet2ClsMSG(0).cuda()
    points = torch.randn(2, 1024, 3).cuda()
    print(points.mean(dim=1))

    output = model(points)
    print(output.shape)
